NucleusMoE-Image #13317

Open
sippycoder wants to merge 21 commits into huggingface:main from sippycoder:main

Conversation

@sippycoder

What does this PR do?

This PR introduces NucleusMoE-Image series into the diffusers library.

NucleusMoE-Image is a 17B-parameter model with 2B active parameters per token, trained with efficiency at its core. Our novel architecture highlights the scalability of sparse MoE architectures for image generation. The technical report will be released very soon.
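For readers unfamiliar with the pattern, a sparse MoE layer routes each token through only a small subset of expert FFNs, which is how a 17B-parameter model can compute only ~2B parameters per token. A minimal top-k routing sketch (illustrative only, not the NucleusMoE-Image implementation; `TopKMoE` and its internals are invented for this example):

```python
import torch
import torch.nn as nn

class TopKMoE(nn.Module):
    """Minimal top-k sparse MoE layer: only k experts run per token."""

    def __init__(self, dim, num_experts=8, k=2):
        super().__init__()
        self.router = nn.Linear(dim, num_experts, bias=False)
        # Toy experts; a real model would use gated FFNs (e.g. SwiGLU).
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))
        self.k = k

    def forward(self, x):  # x: (tokens, dim)
        logits = self.router(x)                       # (tokens, num_experts)
        weights, idx = logits.topk(self.k, dim=-1)    # pick k experts per token
        weights = weights.softmax(dim=-1)             # normalize routing weights
        out = torch.zeros_like(x)
        for slot in range(self.k):
            for e in range(len(self.experts)):
                mask = idx[:, slot] == e              # tokens routed to expert e
                if mask.any():
                    out[mask] += weights[mask, slot, None] * self.experts[e](x[mask])
        return out
```

Production implementations replace the Python loops with grouped/batched matmuls, which is exactly what the packed-weight discussion later in this thread is about.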

@sippycoder
Author

cc: @sayakpaul @IlyasMoutawwakil

@sayakpaul sayakpaul requested review from dg845 and yiyixuxu March 24, 2026 04:08
Comment on lines +545 to +546
gate1 = gate1.clamp(min=-2.0, max=2.0)
gate2 = gate2.clamp(min=-2.0, max=2.0)
Collaborator

It seems weird to me that we first clamp the gates to [-2.0, 2.0] and then essentially clamp again by squashing with the tanh function below. Is this intended?

Author

I agree it's weird. :) I used it to stabilize the gradients if the tanh gates get saturated while training. I will evaluate the model performance without it and get back to you!
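For context on this exchange, a quick stdlib-only numerical check (illustrative, not the model code) of how clamp-then-tanh differs from tanh alone: tanh already bounds the gate to (-1, 1), so the extra clamp merely caps the gate magnitude at tanh(2) ≈ 0.964 and makes the gradient exactly zero outside [-2, 2].

```python
import math

def gate_plain(x):
    # Plain tanh gate: output in (-1, 1), gradient tiny but nonzero when saturated.
    return math.tanh(x)

def gate_clamped(x):
    # Clamp-then-tanh, mirroring the snippet under review: output capped
    # at tanh(2), and the clamp zeroes the gradient for |x| > 2.
    return math.tanh(max(-2.0, min(2.0, x)))

print(gate_plain(10.0))    # ≈ 0.9999999959
print(gate_clamped(10.0))  # ≈ 0.9640276 == tanh(2)
```

Inside [-2, 2] the two gates are identical, so any training-time difference comes entirely from how saturated pre-activations are handled.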

Collaborator

@dg845 dg845 left a comment

Thanks for the PR! Left an initial review :). @yiyixuxu, could you also take a look at the text KV cache code in src/diffusers/hooks/text_kv_cache.py?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +380 to +391
self.experts = nn.ModuleList(
[
FeedForward(
dim=hidden_size,
dim_out=hidden_size,
inner_dim=moe_intermediate_dim,
activation_fn="swiglu",
bias=False,
)
for _ in range(num_experts)
]
)
Member

You would need the projections to be in a packed/contiguous format, shaped (num_experts, dim_in, dim_out), for torch.grouped_mm support. @sayakpaul, is that possible? In Transformers we use the inline weight converter.

Member

Not at the moment, because MoEs are still a bit of a special case in this part of the world.

Author

I can pack the MoE weights. That's how I originally trained the model with Expert Parallel.
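A minimal sketch of the packing being discussed (illustrative only; `torch.bmm` stands in for a grouped-matmul kernel, and the plain `nn.Linear` experts stand in for the `FeedForward` modules above, whose multiple projections would each be packed the same way):

```python
import torch
import torch.nn as nn

num_experts, dim_in, dim_out = 4, 8, 8
experts = nn.ModuleList(nn.Linear(dim_in, dim_out, bias=False) for _ in range(num_experts))

# nn.Linear stores weight as (dim_out, dim_in); transpose each so that a
# slice packed[e] can be applied directly as x @ packed[e]. torch.stack
# copies into one contiguous (num_experts, dim_in, dim_out) tensor.
packed = torch.stack([e.weight.t() for e in experts])

# Tokens already grouped by expert: 3 tokens per expert in this toy case.
x = torch.randn(num_experts, 3, dim_in)
with torch.no_grad():
    y = torch.bmm(x, packed)  # one batched matmul over all experts
```

With this layout, per-expert Python loops over a `ModuleList` collapse into a single kernel launch, which is the motivation for the contiguous format.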

Comment on lines +174 to +175
if max_txt_seq_len is None:
    raise ValueError("`max_txt_seq_len` must be provided.")
Collaborator

Would it be possible to provide a reasonable default value for max_txt_seq_len instead of raising an error?
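One way the suggestion could look (a hypothetical sketch; the default of 512 and the helper name `resolve_max_txt_seq_len` are assumptions for illustration, not the actual diffusers API):

```python
import warnings

DEFAULT_MAX_TXT_SEQ_LEN = 512  # assumed placeholder default

def resolve_max_txt_seq_len(max_txt_seq_len=None):
    # Fall back to a documented default instead of raising, but warn so
    # callers relying on the fallback can see it.
    if max_txt_seq_len is None:
        warnings.warn(
            f"`max_txt_seq_len` not provided; defaulting to {DEFAULT_MAX_TXT_SEQ_LEN}."
        )
        return DEFAULT_MAX_TXT_SEQ_LEN
    return max_txt_seq_len
```

The trade-off is that a silent default can mask misconfiguration, which is why the warning accompanies the fallback.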



class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin):
def test_txt_seq_lens_deprecation(self):
Collaborator

I think we can remove this test now that txt_seq_lens has been removed from the transformer's forward method.



6 participants